from functools import partial
import re

import numpy as np
import pandas as pd
from scipy.spatial.distance import cosine
import statsmodels.stats.api as sm


def prep_qrys(qrys):
    """Order the query log per user and collapse runs of repeated queries.

    Within each user's time-ordered queries, a row is dropped when the *next*
    row of the same user has an identical ``qry``; only the last row of every
    run of consecutive identical queries is kept.

    Args:
        qrys: DataFrame with at least ``user_id``, ``timestamp`` and ``qry``
            columns. It is not modified.

    Returns:
        A new DataFrame sorted by ``user_id`` then ``timestamp``, with a fresh
        0..n-1 index and all original columns (including ``user_id``).
    """
    # groupby drops rows whose key is missing; do so explicitly so the
    # vectorised version below keeps the same set of rows.
    qrys = qrys.dropna(subset=["user_id"])
    qrys = qrys.sort_values(by=["user_id", "timestamp"])

    # drop consecutive identical queries (interactions with image results).
    # A per-user shift replaces groupby.apply: it is vectorised and does not
    # depend on how groupby.apply treats the grouping column, which has
    # changed across pandas releases.
    next_qry = qrys.groupby("user_id")["qry"].shift(-1)
    keep = next_qry != qrys["qry"]
    return qrys[keep].reset_index(drop=True)


def load_vecs():
    """Build the query -> embedding lookup from the saved NumPy arrays.

    Reads ``qrys.npy`` (the keys) and ``embeddings.npy`` (the values) from the
    current working directory and pairs them up position by position.

    Returns:
        dict mapping each entry of ``qrys.npy`` to the entry of
        ``embeddings.npy`` at the same position.
    """
    # NOTE: allow_pickle=True lets np.load unpickle arbitrary objects, so these
    # files must come from a trusted source.
    keys = np.load("qrys.npy", allow_pickle=True)
    values = np.load("embeddings.npy", allow_pickle=True)
    return dict(zip(keys, values))


def add_sessions(df, vec_dict):
    """Annotate one user's queries with the features used for sessionising.

    Adds two columns to ``df`` in place:

    * ``time_diff`` - seconds elapsed since the previous row (NaN on the
      first row).
    * ``cos_sim`` - cosine similarity between this query's embedding and the
      previous query's embedding; the first row is set to 0.

    Args:
        df: DataFrame with a datetime ``timestamp`` column and a ``qry``
            column, already in the order the queries were issued.
        vec_dict: mapping from query to embedding vector; every value of
            ``df.qry`` must be a key.

    Returns:
        The same ``df`` object, with the two columns added.
    """
    df["time_diff"] = df.timestamp.diff().dt.total_seconds()

    embeddings = [vec_dict[query] for query in df.qry]
    # The first query has no predecessor, so its similarity is fixed at 0.
    similarities = [0]
    for previous_vec, current_vec in zip(embeddings, embeddings[1:]):
        # scipy's ``cosine`` is a distance; similarity is its complement.
        similarities.append(1 - cosine(previous_vec, current_vec))
    df["cos_sim"] = similarities
    return df


def get_sessions_time(df, thresh=30):
    """Assign time-gap session ids to one user's queries.

    A new session starts at every row whose ``time_diff`` (seconds) is at
    least ``thresh`` minutes. Ids are the running count of such rows, written
    in place to the column ``session_time_<thresh>``.

    Args:
        df: DataFrame with a ``time_diff`` column in seconds.
        thresh: inactivity gap, in minutes, that separates two sessions.

    Returns:
        The same ``df`` object with the session-id column added.
    """
    gap_seconds = thresh * 60
    # NaN gaps (e.g. a first row with no predecessor) compare as False, so
    # they never open a new session.
    opens_session = df["time_diff"] >= gap_seconds
    df[f"session_time_{thresh}"] = opens_session.cumsum()
    return df


def get_sessions_sim(df, thresh=0.7):
    """Assign similarity-based session ids to one user's queries.

    A new session starts at every row whose ``cos_sim`` is strictly below
    ``thresh``. Ids are the running count of such rows, written in place to
    the column ``session_sim_<thresh>``.

    Args:
        df: DataFrame with a ``cos_sim`` column.
        thresh: similarity below which a query opens a new session.

    Returns:
        The same ``df`` object with the session-id column added.
    """
    is_topic_shift = df["cos_sim"] < thresh
    df[f"session_sim_{thresh}"] = is_topic_shift.cumsum()
    return df


def add_session_idx(df):
    """Number the rows of ``df`` 0, 1, 2, ... in a ``session_idx`` column.

    The column is written in place and the same ``df`` object is returned.
    """
    n_rows = len(df)
    df["session_idx"] = range(n_rows)
    return df


def all_sessions(data, sesh_func, thresh_list):
    """Sessionise every user's queries once per threshold.

    For each threshold, ``sesh_func(user_rows, thresh=thresh)`` is run on each
    user's rows to add a session-id column (``session_time_<thresh>`` or
    ``session_sim_<thresh>``). A ``session_idx`` column is then added holding
    each query's 0-based position inside its (user, session) group, so
    ``session_idx == 0`` marks the first query of a session.

    Args:
        data: DataFrame with a ``user_id`` column plus whatever columns
            ``sesh_func`` reads. It is not modified.
        sesh_func: callable ``(df, thresh=...) -> df`` that adds the
            session-id column for a single user.
        thresh_list: iterable of thresholds to evaluate.

    Returns:
        A list with one DataFrame per threshold, in ``thresh_list`` order.
        Rows are grouped by user (in ``user_id`` order) with a fresh
        0..n-1 index. If ``sesh_func`` adds neither of the two known
        session-id columns, the frame is returned without ``session_idx``.
    """
    sessions = []
    for thresh in thresh_list:
        # Walk the groups explicitly instead of using groupby.apply: the way
        # apply treats the grouping column has changed across pandas
        # releases, and each user gets its own copy to write into.
        per_user = [
            sesh_func(user_rows.copy(), thresh=thresh)
            for _, user_rows in data.groupby("user_id")
        ]
        if not per_user:
            # No users at all: still let sesh_func add its (empty) column.
            per_user = [sesh_func(data.iloc[0:0].copy(), thresh=thresh)]
        sesh = pd.concat(per_user, ignore_index=True)

        # Identify the session-id column by name rather than by comparing
        # function objects, so wrapped sessionisers work as well.
        candidates = (f"session_time_{thresh}", f"session_sim_{thresh}")
        s_col = next((col for col in candidates if col in sesh.columns), None)

        if s_col is not None:
            # cumcount numbers rows within each (user, session) group. A
            # single-query session gets index 0, so short and long sessions
            # need no separate handling and no helper column such as ``url``.
            sesh["session_idx"] = sesh.groupby(["user_id", s_col]).cumcount()
        sessions.append(sesh)
    return sessions


def compare(data_list, thresh_list):
    """Compare race/gender word usage in refined vs. first queries.

    For each sessionised frame, only rows with a non-null ``category`` are
    considered. They are split into first queries (``session_idx == 0``) and
    refined queries (``session_idx > 0``), and the share of rows with a
    non-null ``word_race_gender`` in each part is compared with
    ``statsmodels`` ``confint_proportions_2indep`` (refined queries are the
    first sample, first queries the second; library defaults are used).

    Args:
        data_list: DataFrames with ``category``, ``session_idx`` and
            ``word_race_gender`` columns, one per threshold.
        thresh_list: thresholds matching ``data_list`` position by position.

    Returns:
        DataFrame with one row per threshold and columns ``lo``, ``hi``
        (the interval bounds) and ``thresh``.
    """
    categorised = [frame[frame.category.notna()] for frame in data_list]
    initial_queries = [frame[frame.session_idx == 0] for frame in categorised]
    followup_queries = [frame[frame.session_idx > 0] for frame in categorised]

    def interval_row(first, refined, thresh):
        # Count the queries that contain a race/gender word in each part.
        refined_hits = refined.word_race_gender.notna().sum()
        first_hits = first.word_race_gender.notna().sum()
        lower, upper = sm.confint_proportions_2indep(
            refined_hits, len(refined), first_hits, len(first)
        )
        return {"lo": lower, "hi": upper, "thresh": thresh}

    rows = [
        interval_row(first, refined, thresh)
        for first, refined, thresh in zip(
            initial_queries, followup_queries, thresh_list
        )
    ]
    return pd.DataFrame(rows)


def check_word_match(row, dem_res, cat):
    """Report whether a row's race/gender word matches a demographic category.

    Args:
        row: object exposing a ``word_race_gender`` attribute (e.g. a
            DataFrame row). The value may be missing (NaN/None).
        dem_res: mapping from category name to an iterable of regex patterns
            (compiled or plain strings).
        cat: the category whose patterns should be tried.

    Returns:
        1 if any pattern for ``cat`` is found in the word, otherwise 0.
        Rows whose word is not a string (missing values) yield 0.

    Raises:
        KeyError: if ``cat`` is not a key of ``dem_res``.
    """
    word = row.word_race_gender
    patterns = dem_res[cat]
    # Rows without a race/gender word hold NaN/None, which re.search rejects
    # with a TypeError; they simply do not match.
    if not isinstance(word, str):
        return 0
    return int(any(re.search(pattern, word) is not None for pattern in patterns))


def prep_r_data(data, fp_r_data="qry_data.csv"):
    """Derive 0/1 demographic indicators and write them to a CSV file.

    Columns added to ``data`` in place:

    * ``word_Male`` / ``word_Female`` / ``word_Black`` / ``word_White`` -
      1 if ``word_race_gender`` contains one of the category's words as a
      whitespace-delimited token, else 0 (missing words give 0).
    * ``is_Male`` / ``is_Female`` - whether ``gender`` equals that value.
    * ``is_<race>`` - one boolean column per distinct non-null ``race``
      value, with spaces in the value replaced by underscores.

    Args:
        data: DataFrame with ``word_race_gender``, ``gender``, ``race`` and
            ``user_id`` columns.
        fp_r_data: path of the CSV file to write.

    Returns:
        None. The selected indicator columns and ``user_id`` are written to
        ``fp_r_data`` without the index.
    """
    match_words = {
        "Male": ["man", "men", "male", "boy"],
        "Female": ["woman", "women", "female", "girl"],
        "Black": ["black"],
        "White": ["white"],
    }

    # Cast to object so the .str accessor is available even when every value
    # is missing (such a column would otherwise be float).
    query_words = data["word_race_gender"].astype(object)
    for cat, words in match_words.items():
        # One raw-string pattern per category: any listed word, delimited by
        # whitespace or the string boundaries. Non-capturing groups keep
        # str.contains from warning about match groups.
        alternation = "|".join(re.escape(word) for word in words)
        pattern = rf"(?:^|\s)(?:{alternation})(?:$|\s)"
        # na=False: rows without a race/gender word count as "no match".
        has_word = query_words.str.contains(pattern, regex=True, na=False)
        data[f"word_{cat.replace(' ', '_')}"] = has_word.astype(int)

    data["is_Male"] = data.gender == "Male"
    data["is_Female"] = data.gender == "Female"
    # Skip missing race values: they have no category name to build from.
    for race in data.race.dropna().unique():
        data[f"is_{str(race).replace(' ', '_')}"] = data.race == race

    cols = [
        "word_Female",
        "word_Male",
        "word_Black",
        "word_White",
        "is_Female",
        "is_Male",
        "is_Black",
        "is_Hispanic",
        "is_Asian",
        "is_Two_or_more_races",
        "user_id",
    ]
    # A race category that never occurs in ``data`` would otherwise make the
    # column selection below fail; nobody belongs to it, so it is all False.
    for col in cols:
        if col.startswith("is_") and col not in data.columns:
            data[col] = False

    data[cols].to_csv(fp_r_data, index=False)


def main(fp_input):
    """Run the query-refinement analysis on the query log at ``fp_input``.

    Steps: load and de-duplicate the queries, attach time-gap and
    embedding-similarity features per user, export the demographic indicator
    table, sessionise by similarity for thresholds 0.1 ... 0.9 and compare
    first vs. refined queries at each threshold.

    Args:
        fp_input: path of the CSV query log. It must contain the columns the
            pipeline reads (``user_id``, ``timestamp``, ``qry`` and the
            columns used by ``prep_r_data`` and ``compare``).

    Returns:
        The DataFrame produced by ``compare`` (one row per threshold).
    """
    qrys = pd.read_csv(fp_input)
    # read_csv leaves timestamps as text; add_sessions needs real datetimes
    # for ``.diff().dt.total_seconds()``.
    qrys["timestamp"] = pd.to_datetime(qrys["timestamp"])
    qrys = prep_qrys(qrys)
    vec_dict = load_vecs()

    # Walk the users explicitly instead of using groupby.apply, whose
    # treatment of the grouping column has changed across pandas releases.
    data = pd.concat(
        [
            add_sessions(user_rows.copy(), vec_dict=vec_dict)
            for _, user_rows in qrys.groupby("user_id")
        ]
    )

    prep_r_data(data)

    thresh_list = np.arange(0.1, 1, 0.1)
    qry_sim = all_sessions(data, get_sessions_sim, thresh_list)
    # Hand the comparison back to the caller instead of discarding it.
    return compare(qry_sim, thresh_list)


if __name__ == "__main__":
    import sys

    # main() requires the path of the query CSV; calling it with no argument
    # raises a TypeError, so take the path from the command line.
    if len(sys.argv) != 2:
        sys.exit(f"usage: {sys.argv[0]} QUERIES_CSV")
    main(sys.argv[1])
